{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3b3ab079-532f-45ba-9b55-8ba54fd4cfbf",
   "metadata": {},
   "source": [
    "<!--Copyright © ZOMI 适用于[License](https://github.com/Infrasys-AI/AIInfra)版权许可-->\n",
    "\n",
    "# 手把手实现迷你版 Transformer \n",
    "\n",
    "Author by: lwh"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01454581-cc48-466d-8167-0f7c3d44c1b4",
   "metadata": {},
   "source": [
    "## Transformer 知识原理\n",
    "\n",
    "Transformer 是一种基于 自注意力机制（Self-Attention）的深度学习模型结构，用于处理序列数据（如文本、语音、时间序列等）。相较于传统的循环神经网络（RNN、LSTM），Transformer 有以下几个核心优势：\n",
    "\n",
    "- 并行计算能力强：抛弃了序列依赖的结构，使训练更快；\n",
    "- 长距离依赖建模能力强：自注意力机制可以直接建模任意位置间的关系；\n",
    "- 结构模块化：由多层编码器和解码器堆叠构成，易于扩展和修改。\n",
    "\n",
    "在本实验中，仅关注编码器（Encoder）部分的精简实现，其结构主要由以下模块组成：\n",
    "\n",
    "1. 词嵌入（Embedding）：将输入的 token 序列（如词索引）映射为稠密向量，形成初始特征表示。\n",
    "2. 位置编码（Positional Encoding）：由于 Transformer 完全抛弃了循环结构，因此需要手动注入位置信息。本实验使用论文提出的正余弦位置编码，使模型能够感知序列中 token 的位置信息。\n",
    "3. 自注意力机制（Self-Attention）：通过计算每个位置之间的关系（注意力权重），让模型自主学习哪些位置更重要。其本质是对序列的不同部分进行加权组合，突出关键特征。\n",
    "4. 残差连接 + 层归一化（Residual + LayerNorm）：为了解决深层网络中梯度消失和训练不稳定的问题，在注意力层和前馈网络后分别添加残差连接与 LayerNorm，增强模型表达能力。\n",
    "5. 前馈网络（Feed-Forward Network）：对每个位置的表示分别通过两层全连接网络，进一步提取特征并引入非线性表达。\n",
    "\n",
    "最终，编码器的每个模块都以“子层 → 残差连接 → 层归一化”的方式组成结构块，构成了一个可堆叠的 Transformer 编码器框架。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4849f6c9-fd5e-4fc0-af5b-3256ffb36d37",
   "metadata": {},
   "source": [
    "## Transformer 编码实现\n",
    "\n",
    "首先导入所需的 PyTorch 和数学库："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5b014cf3-d7ca-44d5-b483-e96db5128120",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import math"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f4247a4-b55c-4107-84c8-ee0e05754a07",
   "metadata": {},
   "source": [
    "与论文一致使用固定的正余弦位置编码方式。注意编码维度必须为偶数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "78d6d026-9d3c-4248-914e-f151321e0031",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sinusoidal_pos_encoding(seq_len: int, d_model: int) -> torch.Tensor:\n",
    "    if d_model % 2 != 0:\n",
    "        raise ValueError(\"d_model 必须为偶数\")\n",
    "\n",
    "    # 生成位置向量和维度向量\n",
    "    pos = torch.arange(0, seq_len).unsqueeze(1).float()        # shape: (seq_len, 1)\n",
    "    i = torch.arange(0, d_model // 2).float()                  # shape: (d_model/2,)\n",
    "\n",
    "    # 计算频率除数项\n",
    "    denom = torch.pow(10000, 2 * i / d_model)                  # shape: (d_model/2,)\n",
    "    angle = pos / denom                                        # shape: (seq_len, d_model/2)\n",
    "\n",
    "    # 初始化编码矩阵\n",
    "    pe = torch.zeros(seq_len, d_model)\n",
    "\n",
    "    # 填入 sin 和 cos\n",
    "    pe[:, 0::2] = torch.sin(angle)  # 偶数维度\n",
    "    pe[:, 1::2] = torch.cos(angle)  # 奇数维度\n",
    "\n",
    "    return pe"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b01e43f-73d6-481d-b6bc-cc0a1a6d7633",
   "metadata": {},
   "source": [
    "这个函数返回一个 `(seq_len, d_model)` 的位置编码张量，用于为输入序列添加位置信息。该位置编码函数完成后，就可以进入 Transformer 的核心：注意力机制的实现。点积注意力机制，它会使用三个输入张量：\n",
    "\n",
    "- `q`（Query）：表示当前 token 想关注什么； \n",
    "- `k`（Key）：表示序列中每个位置的“关键词”；\n",
    "- `v`（Value）：表示每个位置实际携带的信息。\n",
    "\n",
    "注意力机制会计算 `q` 与所有 `k` 的匹配程度，然后用这些权重对 `v` 进行加权平均，输出新的表示结果。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a7a5acba-92c3-40b9-859c-8a1d5ae68561",
   "metadata": {},
   "outputs": [],
   "source": [
    "def scaled_dot_product_attention(q, k, v):\n",
    "\n",
    "    # 计算注意力打分矩阵：q 与 k 的转置点积，然后除以 sqrt(d_k) 进行缩放\n",
    "    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))  # shape: (..., seq_len_q, seq_len_k)\n",
    "\n",
    "    # 对打分矩阵使用 softmax，得到注意力权重\n",
    "    weights = torch.softmax(scores, dim=-1)  # shape: (..., seq_len_q, seq_len_k)\n",
    "\n",
    "    # 使用注意力权重加权求和 v，得到注意力输出\n",
    "    return weights @ v  # shape: (..., seq_len_q, d_v)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a93b2821-ec75-4824-a7a4-2ffd0cee3259",
   "metadata": {},
   "source": [
    "接下来构建一个迷你版的 Transformer 编码器类 `MiniTransformerEncoder`，它包含以下核心模块：\n",
    "\n",
    "- 嵌入层（Embedding）：将输入的 token（整数索引）映射为 d_model 维的向量；\n",
    "- 注意力层（Self-Attention）：使用缩放点积注意力机制提取全局上下文信息；\n",
    "- 前馈网络（Feedforward）：对每个位置的表示进行非线性变换；\n",
    "- 层归一化（LayerNorm）：分别应用在注意力子层和前馈子层之后，配合残差连接使用，有助于训练稳定。\n",
    "\n",
    "此外，为每个位置添加了固定的 正余弦位置编码（Positional Encoding），使模型可以识别 token 的先后顺序。\n",
    "\n",
    "注意：直接对词嵌入张量 `x_embed` 与位置编码 `pe` 使用加法（`x = x_embed + pe`），这是因为：\n",
    "\n",
    "- `x_embed` 的形状是 `(1, seq_len, d_model)`（batch 维度为 1）；\n",
    "- `pe` 的形状是 `(seq_len, d_model)`；\n",
    "\n",
    "在 PyTorch 中，这两者可以自动广播（broadcasting）成相同形状，从而完成逐位置相加。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "dd1b58c5-45a6-4ab8-be71-bb61288da895",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MiniTransformerEncoder(nn.Module):\n",
    "    def __init__(self, vocab_size, d_model):\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Embedding(vocab_size, d_model)  # 词嵌入层\n",
    "        self.linear_q = nn.Linear(d_model, d_model)         # Q 映射层\n",
    "        self.linear_k = nn.Linear(d_model, d_model)         # K 映射层\n",
    "        self.linear_v = nn.Linear(d_model, d_model)         # V 映射层\n",
    "        self.attn_output = nn.Linear(d_model, d_model)      # 注意力输出映射\n",
    "        self.norm1 = nn.LayerNorm(d_model)                  # 第一个 LayerNorm\n",
    "\n",
    "        self.ffn = nn.Sequential(                           # 前馈网络：两层全连接\n",
    "            nn.Linear(d_model, d_model * 2),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(d_model * 2, d_model)\n",
    "        )\n",
    "        self.norm2 = nn.LayerNorm(d_model)                  # 第二个 LayerNorm\n",
    "\n",
    "    def forward(self, x):\n",
    "        seq_len = x.size(1)\n",
    "        x_embed = self.embedding(x)                         # 获取词向量表示\n",
    "        pe = sinusoidal_pos_encoding(seq_len, x_embed.size(-1))\n",
    "        x = x_embed + pe                                    # 添加位置编码\n",
    "\n",
    "        # 自注意力子层\n",
    "        q = self.linear_q(x)\n",
    "        k = self.linear_k(x)\n",
    "        v = self.linear_v(x)\n",
    "        attn = scaled_dot_product_attention(q, k, v)\n",
    "        x = self.norm1(x + self.attn_output(attn))          # 残差连接 + LayerNorm\n",
    "\n",
    "        # 前馈子层\n",
    "        ff = self.ffn(x)\n",
    "        x = self.norm2(x + ff)                              # 残差连接 + LayerNorm\n",
    "        return x\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3269b3af-2960-4be8-ba8b-50797095f0aa",
   "metadata": {},
   "source": [
    "这一结构就是一个基本的 Transformer 编码器块，具备捕捉上下文、感知位置、非线性变换的能力。\n",
    "\n",
    "为了测试模型输出，用一个假输入 [3, 1, 7] 来运行前向传播："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5073e145-cc5f-4ab9-bae7-ae4ea6ce4641",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== 输出 ===\n",
      "[[[-1.5413152   1.6801366   0.05804453  0.24212095 -1.0774754\n",
      "    0.02384599 -0.12485133 -0.6640581  -0.88564116  1.0022273\n",
      "   -1.7447253  -0.11834528 -0.32912147  0.8888434   1.3979812\n",
      "    1.1923339 ]\n",
      "  [ 0.7319907   0.89123964  0.20003487 -0.07367807 -1.0843173\n",
      "    0.2291747  -1.023271    0.09876721 -0.5554284   0.05615921\n",
      "   -0.26637512 -1.0664971   0.01956824  1.8496377  -1.9267299\n",
      "    1.9197246 ]\n",
      "  [ 1.0861049  -0.38979262  1.8310152   0.47594324 -0.07559342\n",
      "   -0.44681358 -1.166736    1.6186275  -1.5705074  -0.19102472\n",
      "    0.40794823  1.0515163  -1.3678914  -0.19411024 -1.1033916\n",
      "    0.03470579]]]\n",
      "=== 结束 ===\n"
     ]
    }
   ],
   "source": [
    "model = MiniTransformerEncoder(vocab_size=50, d_model=16)\n",
    "dummy_input = torch.tensor([[3, 1, 7]])  # shape: [1, 3]\n",
    "output = model(dummy_input)\n",
    "print(\"=== 输出 ===\")\n",
    "print(output.detach().numpy())\n",
    "print(\"=== 结束 ===\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6c6d8d6-45e1-4490-bce3-3f58559eb25b",
   "metadata": {},
   "source": [
    "至此，就完成了一个最小的 Transformer 编码器搭建。它结构清晰、功能完整，非常适合用作 Transformer 学习的入门代码框架。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
